- Introduced `nvte_unswizzle_scaling_factors` to convert swizzled scaling factors back to row-major format.
- Implemented `regs_unshuffle_with_bit_shifts` and `regs_unshuffle` for unshuffling operations in CUDA kernels.
- Added `unswizzle_row_scaling_kernel_impl` and `unswizzle_col_scaling_kernel_impl` for handling unswizzling in row and column scaling respectively.

These changes enhance the functionality of the swizzle module, enabling better handling of scaling factors in tensor operations.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
These enhancements test the changes introduced for unswizzling.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Introduced `compute_ref_unswizzle` to handle the conversion of swizzled scaling factors back to their original format.
- Added `performTestUnswizzle1D` to validate the unswizzling process with various scaling modes.
- Created `UnswizzleTestSuite` for comprehensive testing of unswizzling operations.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Moved the definition of `swizzle_row_scaling_kernel` to a new location for better organization.
- Ensured the kernel implementation is now properly defined and accessible for scaling operations in the swizzle module.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Introduced `multi_tensor_unswizzle_scaling_factors` to convert swizzled scaling factors back to their original row-major format.
- Implemented CUDA kernels for unswizzling in both row and column scaling, enhancing the swizzle module's functionality.
- Updated the launch function to handle multiple tensor unswizzling operations efficiently.

These changes improve the handling of scaling factors in tensor operations, ensuring better performance and organization within the swizzle module.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR adds unswizzle support for MXFP8 and NVFP4 scaling factors, providing the inverse of the existing swizzle operation.
Key observations:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Caller
participant nvte_unswizzle_scaling_factors
participant unswizzle_scaling_factors
participant unswizzle_scaling_kernel
participant unswizzle_row_impl as unswizzle_row_scaling_kernel_impl
participant unswizzle_col_impl as unswizzle_col_scaling_kernel_impl
Caller->>nvte_unswizzle_scaling_factors: input (swizzled), output (compact), stream
nvte_unswizzle_scaling_factors->>unswizzle_scaling_factors: convertNVTETensorCheck()
unswizzle_scaling_factors->>unswizzle_scaling_factors: validate scaling_mode, dtype, shapes
alt rowwise_unswizzle
unswizzle_scaling_factors->>unswizzle_scaling_kernel: launch<<<grid,block,slm,stream>>>
unswizzle_scaling_kernel->>unswizzle_row_impl: row_scaling=true
unswizzle_row_impl->>unswizzle_row_impl: load tiles to SLM
unswizzle_row_impl->>unswizzle_row_impl: regs_unshuffle()
unswizzle_row_impl->>unswizzle_row_impl: write compact output
else columnwise_unswizzle
unswizzle_scaling_factors->>unswizzle_scaling_kernel: launch<<<grid,block,slm,stream>>>
unswizzle_scaling_kernel->>unswizzle_col_impl: row_scaling=false
unswizzle_col_impl->>unswizzle_col_impl: load tiles to SLM
unswizzle_col_impl->>unswizzle_col_impl: regs_unshuffle_with_bit_shifts()
unswizzle_col_impl->>unswizzle_col_impl: write compact output
end
unswizzle_scaling_factors-->>Caller: compact scale_inv in output
Last reviewed commit: 621bc16
    if ((rowwise && columnwise) || !(rowwise || columnwise)){
      GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" +
        std::to_string(SF_MODE_Y) + "is not implemented.";
    }
Uninitialized variables used in skip message
When `!(rowwise || columnwise)` is true (neither flag is set), neither the `if (rowwise)` nor the `if (columnwise)` branch executes, leaving `SF_MODE_X` and `SF_MODE_Y` uninitialized. Passing them to `std::to_string()` is undefined behaviour.
The same issue exists in `performTestSwizzleUnswizzleRoundtrip` at line 297.
    -  if ((rowwise && columnwise) || !(rowwise || columnwise)){
    -    GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" +
    -      std::to_string(SF_MODE_Y) + "is not implemented.";
    -  }
    +  if ((rowwise && columnwise) || !(rowwise || columnwise)){
    +    GTEST_SKIP() << "TEST SKIPPED, The scaling mode is not implemented.";
    +  }
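If the skip message should still say which combination was rejected, it can be built from the flags themselves, which are always initialized. The following is only a sketch: `skip_reason` is a hypothetical helper that does not exist in the test file, and the message wording is an assumption. It also uses the fact that "both or neither" is the same condition as `rowwise == columnwise`.

```cpp
#include <optional>
#include <string>

// Hypothetical helper (not part of the PR): returns a skip message when the
// flags are "both set" or "neither set", and std::nullopt when exactly one
// of them is set.
std::optional<std::string> skip_reason(bool rowwise, bool columnwise) {
  // Exactly one flag set: this is the supported case, nothing to skip.
  if (rowwise != columnwise) return std::nullopt;

  // Both or neither: describe the flags instead of SF_MODE_X / SF_MODE_Y,
  // so no uninitialized value is ever read.
  return std::string("TEST SKIPPED, The scaling mode (rowwise=") +
         (rowwise ? "true" : "false") + ", columnwise=" +
         (columnwise ? "true" : "false") + ") is not implemented.";
}
```

In a test body this could be used as `if (auto reason = skip_reason(rowwise, columnwise)) { GTEST_SKIP() << *reason; }`, which keeps the message informative without touching the mode variables.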
    if ((rowwise && columnwise) || !(rowwise || columnwise)){
      GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" +
        std::to_string(SF_MODE_Y) + "is not implemented.";
    }
Uninitialized variables used in skip message (roundtrip test)
Same undefined-behaviour issue as in `performTestUnswizzle1D`: `SF_MODE_X` and `SF_MODE_Y` are uninitialized when `!(rowwise || columnwise)`.
    -  if ((rowwise && columnwise) || !(rowwise || columnwise)){
    -    GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" +
    -      std::to_string(SF_MODE_Y) + "is not implemented.";
    -  }
    +  if ((rowwise && columnwise) || !(rowwise || columnwise)){
    +    GTEST_SKIP() << "TEST SKIPPED, The scaling mode is not implemented.";
    +  }
    if ((rowwise && columnwise) || !(rowwise || columnwise)){
      GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" +
        std::to_string(SF_MODE_Y) + "is not implemented.";
Missing space in skip message
The concatenated string produces "...32is not implemented." (no space before "is"). Add a leading space.
    -  std::to_string(SF_MODE_Y) + "is not implemented.";
    +  std::to_string(SF_MODE_Y) + " is not implemented.";
    if ((rowwise && columnwise) || !(rowwise || columnwise)){
      GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" +
        std::to_string(SF_MODE_Y) + "is not implemented.";
Missing space in skip message (roundtrip test)
Same missing-space issue: the message comes out as "...32is not implemented." without a space.
    -  std::to_string(SF_MODE_Y) + "is not implemented.";
    +  std::to_string(SF_MODE_Y) + " is not implemented.";
      for (int i = 0; i < kVectorSize; i++) regs[i] = new_regs[i];
    }
    template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
Missing blank line after function definition
`regs_unshuffle_with_bit_shifts` ends and the next template declaration begins immediately (no blank line). Every other function pair in this file is separated by a blank line. Add one for consistency.
    -    for (int i = 0; i < kVectorSize; i++) regs[i] = new_regs[i];
    -  }
    -  template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
    +    for (int i = 0; i < kVectorSize; i++) regs[i] = new_regs[i];
    +  }
    +
    +  template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
    const bool rowwise_swizzle = all_has_data || all_nvfp4;
    const bool columnwise_swizzle = all_has_columnwise_data && !all_nvfp4;
Misleading variable names in unswizzle function
`rowwise_swizzle` and `columnwise_swizzle` are declared inside `multi_tensor_unswizzle_scaling_factors` but refer to unswizzle operations, not swizzle. This can confuse future readers about the data-flow direction. Consider renaming to `rowwise_unswizzle` / `columnwise_unswizzle` to match the function's purpose.
    -  const bool rowwise_swizzle = all_has_data || all_nvfp4;
    -  const bool columnwise_swizzle = all_has_columnwise_data && !all_nvfp4;
    +  const bool rowwise_unswizzle = all_has_data || all_nvfp4;
    +  const bool columnwise_unswizzle = all_has_columnwise_data && !all_nvfp4;
Description
This PR adds unswizzle support for scaling factors and extends the swizzle module so scaling tensors can be converted from GEMM-swizzled layout back to compact layout, including multi-tensor paths. It also adds round-trip and standalone tests to validate unswizzle correctness.
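To illustrate what a swizzle→unswizzle round trip checks, here is a minimal host-side sketch. It is not the PR's kernel or its `compute_ref_unswizzle` reference. The names `swizzled_offset`, `swizzle_ref` and `unswizzle_ref` are hypothetical, and the 128x4 tile with the `(r % 32) * 16 + (r / 32) * 4 + c` interleave is an assumed layout chosen for illustration. The round-trip property holds for any layout whose offset function is a bijection.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Assumed tile shape for illustration only; not taken from the PR.
constexpr std::size_t kTileRows = 128;
constexpr std::size_t kTileCols = 4;

// Offset of compact element (r, c) inside the swizzled buffer.
// Requires the row count to be a multiple of 128 and cols a multiple of 4.
std::size_t swizzled_offset(std::size_t r, std::size_t c, std::size_t cols) {
  const std::size_t tiles_per_row = cols / kTileCols;
  const std::size_t tile = (r / kTileRows) * tiles_per_row + (c / kTileCols);
  const std::size_t rr = r % kTileRows;  // row inside the tile
  const std::size_t cc = c % kTileCols;  // column inside the tile
  return tile * kTileRows * kTileCols + (rr % 32) * 16 + (rr / 32) * 4 + cc;
}

// Compact (row-major) -> swizzled.
std::vector<std::uint8_t> swizzle_ref(const std::vector<std::uint8_t>& compact,
                                      std::size_t rows, std::size_t cols) {
  std::vector<std::uint8_t> out(compact.size());
  for (std::size_t r = 0; r < rows; ++r)
    for (std::size_t c = 0; c < cols; ++c)
      out[swizzled_offset(r, c, cols)] = compact[r * cols + c];
  return out;
}

// Swizzled -> compact (row-major): the same index map, read in reverse.
std::vector<std::uint8_t> unswizzle_ref(const std::vector<std::uint8_t>& swizzled,
                                        std::size_t rows, std::size_t cols) {
  std::vector<std::uint8_t> out(swizzled.size());
  for (std::size_t r = 0; r < rows; ++r)
    for (std::size_t c = 0; c < cols; ++c)
      out[r * cols + c] = swizzled[swizzled_offset(r, c, cols)];
  return out;
}
```

A round-trip test then asserts that `unswizzle_ref(swizzle_ref(x, rows, cols), rows, cols) == x` for a buffer `x` whose dimensions are multiples of the tile shape. A standalone unswizzle test instead compares the device output against such a host reference.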
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
- `transformer_engine/common/swizzle/swizzle.cu` and declarations in `transformer_engine/common/include/transformer_engine/swizzle.h`
- `tests/cpp/operator/test_swizzle.cu`, including standalone unswizzle and swizzle→unswizzle round-trip coverage

Checklist: